#include "config.h"
#include <stdio.h>
#include "params/params_inc.h"
#include "params/img_in_6.h"

/*
 * File-scope activation buffers for the LeNet forward pass run by main().
 * Each buffer receives the output of one layer call and is handed to the
 * next layer call as its input.
 *
 * NOTE(review): the network input `img_in` is not defined in this file;
 * it is presumably provided by params/img_in_6.h -- TODO confirm.
 */
dtype c1_dout[6][28][28] = {0};   /* written by lenet_c1, read by lenet_mp1 */
dtype c2_din[6][14][14] = {0};    /* written by lenet_mp1, read by lenet_c2 */
dtype c2_dout[16][10][10] = {0};  /* written by lenet_c2, read by lenet_mp2 */

/* lenet_mp2 writes fc1_din through a [5][5]-row view: 400 = 16 * 5 * 5. */
dtype fc1_din[400];               /* read by lenet_fc1 */
dtype fc1_dout[120];              /* written by lenet_fc1, read by lenet_fc2 */
dtype fc2_dout[84];               /* written by lenet_fc2, read by lenet_fc3 */
dtype fc3_dout[10];               /* written by lenet_fc3, final scores */

/*
 * Runs one forward pass of the LeNet pipeline on img_in and prints the result.
 *
 * Layer order (names suggest conv / max-pool / fully-connected stages --
 * inferred from the function names, not visible here):
 *   lenet_c1 -> lenet_mp1 -> lenet_c2 -> lenet_mp2 -> lenet_fc1 -> lenet_fc2
 *   -> lenet_fc3
 * Each call writes into a file-scope buffer that the following call reads.
 * img_in and the *_weight / *_bias arrays are not defined in this file;
 * presumably they come from the params/ headers (TODO confirm).
 *
 * Output: one "%f" line per fc3 output score, followed by
 * "arg max = <index>", where <index> is the position of the largest score
 * (the first occurrence wins on ties).
 *
 * argc/argv are unused. Always returns 0.
 */
int main(int argc, char const *argv[])
{
    (void)argc;
    (void)argv;

    /* The casts present each buffer with the pointer-to-array type the
       corresponding layer function takes. */
    lenet_c1(
        (dtype (*)[32][32])     img_in,
        (dtype (*)[1][5][5])    c1_weight,
        (dtype *)               c1_bias,
        (dtype (*)[28][28])     c1_dout);
    lenet_mp1(
        (dtype (*)[28][28])c1_dout,
        (dtype (*)[14][14])c2_din);
    lenet_c2(
        (dtype (*)[14][14])     c2_din,
        (dtype (*)[6][5][5])    c2_weight,
        (dtype *)               c2_bias,
        (dtype (*)[10][10])     c2_dout);
    /* fc1_din is a flat [400] buffer; pooling writes it as 16 rows of [5][5]. */
    lenet_mp2(
        (dtype (*)[10][10])c2_dout,
        (dtype (*)[5][5])fc1_din);
    lenet_fc1(
        (dtype *)fc1_din,
        (dtype *)fc1_dout,
        (dtype (*)[400])fc1_weight,
        (dtype *)fc1_bias);
    lenet_fc2(
        (dtype *)fc1_dout,
        (dtype *)fc2_dout,
        (dtype (*)[120])fc2_weight,
        (dtype *)fc2_bias);
    lenet_fc3(
        (dtype *)fc2_dout,
        (dtype *)fc3_dout,
        (dtype (*)[84])fc3_weight,
        (dtype *)fc3_bias);

    const int n_classes = (int)(sizeof fc3_dout / sizeof fc3_dout[0]);

    /* Compare against the current best score instead of a seed of 0:
       nothing guarantees the scores are positive, and a 0 seed would leave
       arg_max stuck at 0 whenever every score is <= 0. */
    int arg_max = 0;
    for (int i = 0; i < n_classes; i++) {
        if (fc3_dout[i] > fc3_dout[arg_max]) {
            arg_max = i;
        }
        /* Explicit conversion so the argument matches %f whatever dtype is. */
        printf("%f\n", (double)fc3_dout[i]);
    }
    printf("arg max = %d\n", arg_max);

    return 0;
}

